{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Test XGB.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"S_4xAvfBqZE6","executionInfo":{"status":"ok","timestamp":1619337756178,"user_tz":-420,"elapsed":5558,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"95324c84-5223-4655-ce92-ca532dbb4b97"},"source":["!pip install xgboost"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Requirement already satisfied: xgboost in /usr/local/lib/python3.7/dist-packages (0.90)\n","Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from xgboost) (1.4.1)\n","Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from xgboost) (1.19.5)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"OGMu2BILpRpl","executionInfo":{"status":"ok","timestamp":1619337778423,"user_tz":-420,"elapsed":21291,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"2273effb-1a22-4179-8b5c-07077f77d8ee"},"source":["# import library\n","import cv2\n","import glob\n","import csv\n","import itertools\n","import numpy as np\n","import pandas as pd\n","import xgboost as xgb\n","from sklearn.neighbors import KNeighborsClassifier\n","from sklearn.metrics import precision_score\n","from google.colab import drive\n","# mount google drive, required passcode from chosen google account\n","drive.mount('/content/drive')"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Mounted at 
/content/drive\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"sFKjdaiypV_B"},"source":["# constant\n","drive_url = \"/content/drive/Shared drives/Computer Vision\"\n","dataset_url = \"/content/drive/Shared drives/Computer Vision/dataset\"\n","csv_url = \"/content/drive/Shared drives/Computer Vision/CSV/\"\n","# number of samples used per class (the first dataset_size rows of each CSV)\n","dataset_size = 24\n","# the first train_dataset_size samples of each class go to training, the rest to testing\n","train_dataset_size = 18"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"b42iy8b6qwgb"},"source":["# prepare dataset\n","name_list = [\"1\",\"2\",\"3\",\"4\",\"5\",\"ก\",\"ง2\",\"จ1\",\"จ2\",\"ฉ1\",\"ซ2\",\"ด\",\"ต\",\"ท1\",\"น\",\"บ\",\"พ\",\"ฟ\",\"ม\",\"ย\",\"ร\",\"ล\",\"ว\",\"ส\",\"ห\",\"อ\"]\n","train_dataset = []\n","test_dataset = []\n","train_label = []\n","test_label = []\n","for idx, name in enumerate(name_list):\n","  # read csv, keep the first dataset_size rows and drop the first column\n","  # (presumably the index column written by pandas - TODO confirm)\n","  df=pd.read_csv( csv_url + name + '.csv')\n","  t_all_data = df.to_numpy()[:dataset_size,1:]\n","  # remove z column: keep the first two of every three columns (x, y)\n","  t_all_data = t_all_data[:,[True,True,False]*21]\n","  # split into train and test\n","  t_train_dataset = t_all_data[:train_dataset_size]\n","  t_test_dataset = t_all_data[train_dataset_size:]\n","  # create labels sized to the actual splits so samples and labels always\n","  # line up, even if a CSV has fewer than dataset_size rows\n","  t_train_label = np.full(len(t_train_dataset), idx)\n","  t_test_label = np.full(len(t_test_dataset), idx)\n","  # add temp dataset to list\n","  train_dataset.append(t_train_dataset)\n","  test_dataset.append(t_test_dataset)\n","  train_label.append(t_train_label)\n","  test_label.append(t_test_label)\n","\n","train_dataset = np.concatenate(train_dataset)\n","test_dataset = np.concatenate(test_dataset)\n","train_label = np.concatenate(train_label)\n","test_label = np.concatenate(test_label)\n","# prepare xgb DMatrix\n","xgb_train = xgb.DMatrix(train_dataset, label=train_label)\n","xgb_test = xgb.DMatrix(test_dataset, 
label=test_label)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"YpVWgGOHsu08","executionInfo":{"status":"ok","timestamp":1619262178168,"user_tz":-420,"elapsed":1813,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"51bc1162-2590-4557-dc19-833a1f679d3b"},"source":["# train parameter\n","xgb_params_tree = {\n"," 'eta':0.05,\n"," 'booster':\"gbtree\",\n"," 'max_depth':2,\n"," 'gamma':0,\n"," 'subsample':0.5,\n"," 'min_child_weight':1,\n"," 'colsample_bytree':0.01,\n"," 'objective':\"multi:softprob\",\n"," 'eval_metric':\"merror\",\n"," 'num_class':26\n","}\n","xgb_params_linear = {\n"," 'booster':\"gblinear\",\n"," 'objective':\"multi:softprob\",\n"," 'eval_metric':\"merror\",\n"," 'num_class':26\n","}\n","# training\n","xgb_model = xgb.train(params=xgb_params_linear, \\\n"," dtrain=xgb_train, \\\n"," num_boost_round=100, \\\n"," evals=[(xgb_train,'xgb_train'),(xgb_test,'xgb_test')], \\\n"," early_stopping_rounds=50)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[0]\txgb_train-merror:0.185897\txgb_test-merror:0.262821\n","Multiple eval metrics have been passed: 'xgb_test-merror' will be used for early stopping.\n","\n","Will train until xgb_test-merror hasn't improved in 50 
rounds.\n","[1]\txgb_train-merror:0.115385\txgb_test-merror:0.185897\n","[2]\txgb_train-merror:0.087607\txgb_test-merror:0.134615\n","[3]\txgb_train-merror:0.070513\txgb_test-merror:0.108974\n","[4]\txgb_train-merror:0.061966\txgb_test-merror:0.102564\n","[5]\txgb_train-merror:0.053419\txgb_test-merror:0.096154\n","[6]\txgb_train-merror:0.049145\txgb_test-merror:0.089744\n","[7]\txgb_train-merror:0.049145\txgb_test-merror:0.089744\n","[8]\txgb_train-merror:0.049145\txgb_test-merror:0.083333\n","[9]\txgb_train-merror:0.038462\txgb_test-merror:0.064103\n","[10]\txgb_train-merror:0.036325\txgb_test-merror:0.057692\n","[11]\txgb_train-merror:0.036325\txgb_test-merror:0.057692\n","[12]\txgb_train-merror:0.036325\txgb_test-merror:0.051282\n","[13]\txgb_train-merror:0.032051\txgb_test-merror:0.044872\n","[14]\txgb_train-merror:0.029915\txgb_test-merror:0.038462\n","[15]\txgb_train-merror:0.027778\txgb_test-merror:0.038462\n","[16]\txgb_train-merror:0.025641\txgb_test-merror:0.038462\n","[17]\txgb_train-merror:0.025641\txgb_test-merror:0.038462\n","[18]\txgb_train-merror:0.025641\txgb_test-merror:0.044872\n","[19]\txgb_train-merror:0.021368\txgb_test-merror:0.038462\n","[20]\txgb_train-merror:0.021368\txgb_test-merror:0.038462\n","[21]\txgb_train-merror:0.021368\txgb_test-merror:0.038462\n","[22]\txgb_train-merror:0.021368\txgb_test-merror:0.038462\n","[23]\txgb_train-merror:0.021368\txgb_test-merror:0.038462\n","[24]\txgb_train-merror:0.021368\txgb_test-merror:0.044872\n","[25]\txgb_train-merror:0.019231\txgb_test-merror:0.044872\n","[26]\txgb_train-merror:0.019231\txgb_test-merror:0.044872\n","[27]\txgb_train-merror:0.019231\txgb_test-merror:0.044872\n","[28]\txgb_train-merror:0.021368\txgb_test-merror:0.044872\n","[29]\txgb_train-merror:0.021368\txgb_test-merror:0.051282\n","[30]\txgb_train-merror:0.021368\txgb_test-merror:0.051282\n","[31]\txgb_train-merror:0.021368\txgb_test-merror:0.051282\n","[32]\txgb_train-merror:0.021368\txgb_test-merror:0.057692\n","[33]\txgb_tra
in-merror:0.021368\txgb_test-merror:0.057692\n","[34]\txgb_train-merror:0.021368\txgb_test-merror:0.057692\n","[35]\txgb_train-merror:0.021368\txgb_test-merror:0.057692\n","[36]\txgb_train-merror:0.019231\txgb_test-merror:0.057692\n","[37]\txgb_train-merror:0.019231\txgb_test-merror:0.057692\n","[38]\txgb_train-merror:0.017094\txgb_test-merror:0.057692\n","[39]\txgb_train-merror:0.019231\txgb_test-merror:0.057692\n","[40]\txgb_train-merror:0.019231\txgb_test-merror:0.057692\n","[41]\txgb_train-merror:0.019231\txgb_test-merror:0.051282\n","[42]\txgb_train-merror:0.019231\txgb_test-merror:0.051282\n","[43]\txgb_train-merror:0.019231\txgb_test-merror:0.051282\n","[44]\txgb_train-merror:0.019231\txgb_test-merror:0.051282\n","[45]\txgb_train-merror:0.019231\txgb_test-merror:0.051282\n","[46]\txgb_train-merror:0.019231\txgb_test-merror:0.051282\n","[47]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[48]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[49]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[50]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[51]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[52]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[53]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[54]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[55]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[56]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[57]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[58]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[59]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[60]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[61]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[62]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[63]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","[64]\txgb_train-merror:0.014957\txgb_test-merror:0.051282\n","Stopping. 
Best iteration:\n","[14]\txgb_train-merror:0.029915\txgb_test-merror:0.038462\n","\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"WGt9rnWtO-Dn","executionInfo":{"status":"ok","timestamp":1619262186052,"user_tz":-420,"elapsed":1032,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"588183c6-2365-4a87-e0e3-509ebe0de8e3"},"source":["# predict\n","# train dataset\n","preds = xgb_model.predict(xgb_train)\n","best_preds = np.asarray([np.argmax(line) for line in preds])\n","# print(preds)\n","# print(best_preds)\n","print(\"train set score:\",precision_score(train_label, best_preds, average='macro'))\n","# test dataset\n","preds = xgb_model.predict(xgb_test)\n","best_preds = np.asarray([np.argmax(line) for line in preds])\n","# print(preds)\n","# print(best_preds)\n","print(\"test set score:\",precision_score(test_label, best_preds, average='macro'))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["train set score: 0.9857142857142857\n","test set score: 0.9565018315018315\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"MoVskJCQiFnT","executionInfo":{"status":"ok","timestamp":1619271659245,"user_tz":-420,"elapsed":1054,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"9fc560bd-5a21-4a9e-ca83-622658d0939e"},"source":["# list wrong predictions on the test set\n","# test predictions are recomputed here so this cell does not depend on\n","# whichever cell last assigned best_preds (train set, tuning loop, ...),\n","# and the true class of each sample comes from test_label instead of a\n","# hard-coded number of test samples per class\n","test_best_preds = xgb_model.predict(xgb_test).argmax(axis=1)\n","for true_idx, pred_idx in zip(test_label, test_best_preds):\n","  if pred_idx != true_idx:\n","    print('answer:',name_list[true_idx],'predict: ',name_list[pred_idx])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["answer: จ1 predict: ย\n","answer: 
จ1 predict: ย\n","answer: ด predict: 1\n","answer: ต predict: ท1\n","answer: บ predict: ฟ\n","answer: ม predict: จ2\n","answer: อ predict: ส\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"1wNl39KTib8W","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1619260769691,"user_tz":-420,"elapsed":7672,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"1d9cf1de-fac7-4008-c974-e11da7fc2ccf"},"source":["# params tune\n","_booster = ['gblinear']\n","_eta = [0.01,0.05,0.1]\n","_max_depth = [2]\n","_gamma = [0]\n","_subsample = [0.3,0.5,0.7]\n","_min_child_weight = [1]\n","_num_boost_round = [100]\n","\n","# loop\n","for _b,_e,_ma,_g,_s,_mi,_n in itertools.product(_booster,_eta,_max_depth,_gamma,_subsample,_min_child_weight,_num_boost_round):\n"," # train parameter\n"," xgb_params = {\n"," 'eta':_e,\n"," 'booster':_b,\n"," 'max_depth':_ma,\n"," 'gamma':_g,\n"," 'subsample':_s,\n"," 'min_child_weight':_mi,\n"," 'colsample_bytree':0.01,\n"," 'objective':\"multi:softprob\",\n"," 'eval_metric':\"merror\",\n"," 'num_class':26,\n"," }\n"," # training\n"," xgb_model = xgb.train(params=xgb_params, \\\n"," dtrain=xgb_train, \\\n"," num_boost_round=_n, \\\n"," evals=[(xgb_train,'xgb_train'),(xgb_test,'xgb_test')], \\\n"," early_stopping_rounds=50, \\\n"," verbose_eval = False)\n"," # predict\n"," # train dataset\n"," preds = xgb_model.predict(xgb_train)\n"," best_preds = np.asarray([np.argmax(line) for line in preds])\n"," # print(preds)\n"," # print(best_preds)\n"," print(\"booster\",_b,\"eta\",_e,\"max_depth\",_ma,\"gamma\",_g,\"subsample\",_s,\"min_child_weight\",_mi,\"num_boost_round\",_n,\"\\ntrain set score:\",precision_score(train_label, best_preds, average='macro'))\n"," # test dataset\n"," preds = xgb_model.predict(xgb_test)\n"," best_preds = np.asarray([np.argmax(line) for line in 
preds])\n"," # print(preds)\n"," # print(best_preds)\n"," print(\"test set score:\",precision_score(test_label, best_preds, average='macro'))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["booster gblinear eta 0.01 max_depth 2 gamma 0 subsample 0.3 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9217640194817085\n","test set score: 0.896565934065934\n","booster gblinear eta 0.01 max_depth 2 gamma 0 subsample 0.5 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9217640194817085\n","test set score: 0.896565934065934\n","booster gblinear eta 0.01 max_depth 2 gamma 0 subsample 0.7 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9217640194817085\n","test set score: 0.896565934065934\n","booster gblinear eta 0.05 max_depth 2 gamma 0 subsample 0.3 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9615814611944642\n","test set score: 0.9546703296703297\n","booster gblinear eta 0.05 max_depth 2 gamma 0 subsample 0.5 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9615814611944642\n","test set score: 0.9546703296703297\n","booster gblinear eta 0.05 max_depth 2 gamma 0 subsample 0.7 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9615814611944642\n","test set score: 0.9546703296703297\n","booster gblinear eta 0.1 max_depth 2 gamma 0 subsample 0.3 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9788461538461538\n","test set score: 0.9684065934065935\n","booster gblinear eta 0.1 max_depth 2 gamma 0 subsample 0.5 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9788461538461538\n","test set score: 0.9684065934065935\n","booster gblinear eta 0.1 max_depth 2 gamma 0 subsample 0.7 min_child_weight 1 num_boost_round 100 \n","train set score: 0.9788461538461538\n","test set score: 0.9684065934065935\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"XsTZtY9ydJgw"},"source":["# save model\n","# xgb_model.save_model(drive_url + 
'/models/xgb_model_tree.model')\n","# load model\n","xgb_model = xgb.Booster()\n","xgb_model.load_model(drive_url + '/models/xgb_model_linear.model')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mEAddzA0QrBN","executionInfo":{"status":"ok","timestamp":1619337826634,"user_tz":-420,"elapsed":626,"user":{"displayName":"Puttimeth Suppawawisit","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgiHK03V3vjpdHzuwpg2CL9CRKkjGhdB933ZMvfxA=s64","userId":"13505766524058892004"}},"outputId":"ad2d47e8-0ef0-4598-a89a-d231c45820c0"},"source":["# use\n","name_list = [\"1\",\"2\",\"3\",\"4\",\"5\",\"ก\",\"ง2\",\"จ1\",\"จ2\",\"ฉ1\",\"ซ2\",\"ด\",\"ต\",\"ท1\",\"น\",\"บ\",\"พ\",\"ฟ\",\"ม\",\"ย\",\"ร\",\"ล\",\"ว\",\"ส\",\"ห\",\"อ\"]\n","print(test_dataset.shape, test_dataset)\n","xgb_test2 = xgb.DMatrix(test_dataset)\n","preds = xgb_model.predict(xgb_test2)\n","best_preds = np.asarray([np.argmax(line) for line in preds])\n","print(preds)\n","print(best_preds)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(156, 42) [[0.7238335 1. 0.41014032 ... 0.71479389 0.89639278 0.73377014]\n"," [0.46485637 1. 0.14966042 ... 0.73136766 0.82685689 0.7684536 ]\n"," [0.76864963 1. 0.43177988 ... 0.60373442 0.89500268 0.65116277]\n"," ...\n"," [0.89486298 1. 0.52402015 ... 0.3539521 0.9556219 0.38818387]\n"," [0.86056473 1. 0.4758993 ... 0.3613922 0.97062143 0.41939643]\n"," [0.84320948 1. 0.47039529 ... 0.37663894 0.894317 0.3922176 ]]\n","[[9.84311283e-01 0.00000000e+00 4.15644472e-12 ... 7.69749661e-17\n"," 5.58959186e-07 1.50432175e-17]\n"," [9.48601246e-01 0.00000000e+00 7.00464399e-15 ... 3.35175993e-16\n"," 1.07347105e-07 3.65557402e-19]\n"," [8.10811639e-01 0.00000000e+00 4.36542746e-11 ... 2.12869522e-12\n"," 1.18375624e-06 3.79682552e-14]\n"," ...\n"," [2.81742321e-21 4.33554738e-40 4.83074858e-10 ... 
2.66825737e-05\n"," 3.17918358e-10 9.99971390e-01]\n"," [4.07454242e-20 9.31639271e-41 2.89795743e-10 ... 1.00890808e-04\n"," 1.41603881e-10 9.99865174e-01]\n"," [2.14040718e-18 3.28105441e-35 9.95969884e-09 ... 8.04262527e-06\n"," 4.23281143e-09 9.99954104e-01]]\n","[ 0 0 0 0 0 0 1 1 1 1 1 1 2 2 2 2 2 2 3 3 3 3 3 3\n"," 4 4 4 4 4 4 5 5 5 5 5 5 6 6 6 6 6 6 7 7 7 7 19 19\n"," 8 8 8 8 8 8 9 9 9 9 9 9 10 10 10 10 10 10 11 0 11 11 11 11\n"," 12 12 12 13 12 12 13 13 13 13 13 13 14 14 14 14 14 14 15 15 15 15 15 17\n"," 16 16 16 16 16 16 17 17 17 17 17 17 18 8 18 18 18 18 19 19 19 19 19 19\n"," 20 20 20 20 20 20 21 21 21 21 21 21 22 22 22 22 22 22 23 23 23 23 23 23\n"," 24 24 24 24 24 24 23 25 25 25 25 25]\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"0TusP1ESB2r_"},"source":[""],"execution_count":null,"outputs":[]}]}